{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7e35746-4358-4eb2-bc79-667c96e5cb9a",
   "metadata": {},
   "source": [
    "# LoRA\n",
    "\n",
    "The goal of this practical is to adapt the code of [minGPT](https://github.com/karpathy/minGPT/) from [Karpathy](https://karpathy.ai/) in order to incorporate Low Rank Adaptation (LoRA) for fine-tuning.\n",
    "\n",
    "![](https://miro.medium.com/v2/resize:fit:720/format:webp/1*D_i25E9dTd_5HMa45zITSg.png)\n",
    "\n",
    "This [blog](https://r4j4n.github.io/blogs/posts/lora/) by [Rajan Ghimire](https://r4j4n.github.io/blogs/about/) is a nice introduction to LoRA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62e99540-5b94-4b91-87ee-8cdfc1118dd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import zlib\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7497d1b9-fb9c-4507-a76d-0e662472db41",
   "metadata": {},
   "source": [
    "## Building a custom Linear module\n",
    "\n",
    "methods\n",
    "- [`forward`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.forward)\n",
    "- [`train`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.train)\n",
    "- [`eval`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval)\n",
    "- [`reset_parameters`](https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/modules/linear.py#L50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0157cef-25c6-492b-a5e0-b2bebe56c28e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoRALinear(nn.Linear):\n",
    "    \"\"\"Linear layer with an optional low-rank (LoRA) update of its weight.\n",
    "\n",
    "    With ``lora_rank > 0`` the layer computes\n",
    "    ``x @ W.T + bias + lora_scaling * (x @ lora_A.T @ lora_B.T)`` where\n",
    "    ``lora_scaling = lora_alpha / lora_rank``. With ``lora_rank == 0`` it\n",
    "    behaves exactly like ``nn.Linear``.\n",
    "\n",
    "    In training mode the low-rank branch is evaluated separately. In eval\n",
    "    mode the update ``lora_scaling * lora_B @ lora_A`` is added into\n",
    "    ``weight`` once (``has_weights_merged`` tracks this) and it is removed\n",
    "    again when the layer goes back to training mode.\n",
    "\n",
    "    NOTE: a ``state_dict`` taken in eval mode holds the merged ``weight``\n",
    "    together with ``lora_A`` / ``lora_B``; take checkpoints in train mode.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 # nn.Linear parameters\n",
    "                 in_features: int,\n",
    "                 out_features: int,\n",
    "                 bias: bool = True,\n",
    "                 device=None,\n",
    "                 dtype=None,\n",
    "                 # LoRA parameters\n",
    "                 lora_rank: int = 0,\n",
    "                 lora_alpha: float = 0.0,\n",
    "                ) -> None:\n",
    "        nn.Linear.__init__(\n",
    "            self,\n",
    "            in_features=in_features,\n",
    "            out_features=out_features,\n",
    "            bias=bias,\n",
    "            device=device,\n",
    "            dtype=dtype\n",
    "        )\n",
    "\n",
    "        # LoRA stuff\n",
    "        self.has_weights_merged = False\n",
    "        self.lora_rank = lora_rank\n",
    "        if lora_rank > 0:\n",
    "            self.lora_scaling = lora_alpha / lora_rank\n",
    "            self.lora_A = nn.Parameter(torch.empty((lora_rank, self.in_features), device=device, dtype=dtype))\n",
    "            self.lora_B = nn.Parameter(torch.empty((self.out_features, lora_rank), device=device, dtype=dtype))\n",
    "\n",
    "            # the LoRA factors start frozen; they are switched on explicitly for fine-tuning\n",
    "            self.lora_A.requires_grad = False\n",
    "            self.lora_B.requires_grad = False\n",
    "\n",
    "            self.reset_parameters_lora()\n",
    "\n",
    "    def reset_parameters_lora(self) -> None:\n",
    "        \"\"\"Initialise the LoRA factors so that the initial update ``lora_B @ lora_A`` is zero.\"\"\"\n",
    "        if self.lora_rank <= 0:\n",
    "            return\n",
    "        # lora_A gets a Kaiming-uniform init, lora_B starts at zero\n",
    "        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n",
    "        nn.init.zeros_(self.lora_B)\n",
    "\n",
    "    def _lora_delta(self) -> torch.Tensor:\n",
    "        \"\"\"Return the dense update ``lora_scaling * lora_B @ lora_A``, shaped like ``weight``.\"\"\"\n",
    "        return self.lora_scaling * (self.lora_B @ self.lora_A)\n",
    "\n",
    "    def forward(self, input: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Apply the linear map, plus the low-rank branch while it is not merged.\"\"\"\n",
    "        x = nn.Linear.forward(self, input)\n",
    "        # once merged, the update is already part of self.weight\n",
    "        if self.lora_rank > 0 and not self.has_weights_merged:\n",
    "            low_rank = F.linear(F.linear(input, self.lora_A), self.lora_B)\n",
    "            x = x + self.lora_scaling * low_rank\n",
    "        return x\n",
    "\n",
    "    def train(self, mode: bool = True) -> \"LoRALinear\":\n",
    "        \"\"\"Switch mode; un-merge the update for training, merge it for evaluation.\"\"\"\n",
    "        nn.Linear.train(self, mode)\n",
    "        if self.lora_rank > 0:\n",
    "            if mode and self.has_weights_merged:\n",
    "                # back to training: take the update out of the dense weight\n",
    "                with torch.no_grad():\n",
    "                    self.weight -= self._lora_delta()\n",
    "                self.has_weights_merged = False\n",
    "            elif not mode and not self.has_weights_merged:\n",
    "                # evaluation: fold the update into the dense weight once\n",
    "                with torch.no_grad():\n",
    "                    self.weight += self._lora_delta()\n",
    "                self.has_weights_merged = True\n",
    "        return self\n",
    "\n",
    "    def eval(self) -> \"LoRALinear\":\n",
    "        \"\"\"Switch to evaluation mode with the LoRA update merged into ``weight``.\"\"\"\n",
    "        # nn.Module.eval() dispatches to self.train(False), which performs the merge\n",
    "        nn.Linear.eval(self)\n",
    "        return self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4214d5ed-a19d-4d04-9a79-e2af5459bf69",
   "metadata": {},
   "outputs": [],
   "source": [
    "ln = LoRALinear(in_features=3,out_features=4, lora_rank = 8, lora_alpha = 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ce86682-6464-41c6-9378-fda7b39309db",
   "metadata": {},
   "outputs": [],
   "source": [
    "ln.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06f4ce4-7504-4dcf-b0c3-47d32a1c42bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "ln.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14060718-1134-4f9e-afbe-3e9614253d7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "for p in ln.parameters():\n",
    "    print(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1049ae8a-7c02-4f8d-a73f-8cbd06c57ef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 5\n",
    "x = torch.randn((bs, 3))\n",
    "y = ln(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f33815-49a8-4615-bb50-0b0dde966fe6",
   "metadata": {},
   "outputs": [],
   "source": [
    "y2 = x@ln.weight.T + ln.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50b05904-7647-4d54-88a9-b95022e48905",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.isclose(y,y2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81d30d8a-2658-4a87-a615-fea2e5024df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ln.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c0f25b0-ee1f-48c1-91c9-72fb3f288118",
   "metadata": {},
   "outputs": [],
   "source": [
    "y3 = ln(x)\n",
    "torch.isclose(y3,y2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8863303c-e4c3-47a8-b9bd-f1dedf347abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "ln.eval()\n",
    "y3 = ln(x)\n",
    "torch.isclose(y3,y2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2504cf76-424e-4309-b03d-5f62857c73da",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_lora_model(model: nn.Module) -> nn.Module:\n",
    "    \"\"\"Freeze every parameter except those whose name contains 'lora'.\n",
    "\n",
    "    The model is modified in place and returned for convenience.\n",
    "    \"\"\"\n",
    "    for param_name, parameter in model.named_parameters():\n",
    "        # trainable if and only if it is a LoRA factor\n",
    "        parameter.requires_grad = \"lora\" in param_name\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "465263bc-91a5-4c16-9917-6b19859d1ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "ln_lora = get_lora_model(ln)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0747b55f-5b54-4be2-85b9-6d992c1fcaee",
   "metadata": {},
   "outputs": [],
   "source": [
    "for p in ln_lora.parameters():\n",
    "    print(p)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d012906-a5f8-40db-a58e-7342c6f3c486",
   "metadata": {},
   "source": [
    "## Use the LoRA layer in the building blocks of minGPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dce4036-55f7-418e-a7b7-8995445452d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mingpt.model import CausalSelfAttention\n",
    "\n",
    "class CausalSelfAttention_LoRA(CausalSelfAttention):\n",
    "    \"\"\"CausalSelfAttention whose ``c_attn`` and ``c_proj`` layers are LoRALinear modules.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        # both projections share the same LoRA hyperparameters\n",
    "        lora_kwargs = dict(lora_rank=config.lora_rank, lora_alpha=config.lora_alpha)\n",
    "        embed_dim = config.n_embd\n",
    "        # input projection, replaced with 3 * n_embd output features\n",
    "        self.c_attn = LoRALinear(in_features=embed_dim, out_features=3 * embed_dim, **lora_kwargs)\n",
    "        # output projection\n",
    "        self.c_proj = LoRALinear(in_features=embed_dim, out_features=embed_dim, **lora_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8cc01ba-3eea-4eb7-a8a9-666294932b3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mingpt.model import Block, NewGELU\n",
    "\n",
    "class Block_LoRA(Block):\n",
    "    \"\"\" a Block whose attention layer is replaced by CausalSelfAttention_LoRA \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        Block.__init__(self, config)\n",
    "        # swap in the LoRA-enabled attention layer\n",
    "        self.attn = CausalSelfAttention_LoRA(config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a1a3df7-ea4a-4d2f-802e-0b0ff9e480e8",
   "metadata": {},
   "source": [
    "Same thing for the GPT module and you can simplify the configuration of the optimizer for the LoRA module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd4233f3-9fc6-4de8-ad7d-f2c78a838889",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mingpt.model import GPT\n",
    "\n",
    "class GPT_LoRA(GPT):\n",
    "    \"\"\"GPT whose transformer blocks are Block_LoRA modules.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        # rebuild the transformer so that every block is a Block_LoRA\n",
    "        self.transformer = nn.ModuleDict(dict(\n",
    "            wte = nn.Embedding(config.vocab_size, config.n_embd),\n",
    "            wpe = nn.Embedding(config.block_size, config.n_embd),\n",
    "            drop = nn.Dropout(config.embd_pdrop),\n",
    "            h = nn.ModuleList([Block_LoRA(config) for _ in range(config.n_layer)]),\n",
    "            ln_f = nn.LayerNorm(config.n_embd),\n",
    "        ))\n",
    "        self.config = config\n",
    "        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper\n",
    "        self.apply(self._init_weights)\n",
    "        for pn, p in self.named_parameters():\n",
    "            if pn.endswith('c_proj.weight'):\n",
    "                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))\n",
    "\n",
    "    def configure_optimizers(self, train_config):\n",
    "        \"\"\"\n",
    "        Build the AdamW optimizer over the trainable parameters only.\n",
    "\n",
    "        Parameters with requires_grad=False are left out, so the same method serves\n",
    "        full training (LoRA factors frozen) and LoRA fine-tuning (only the lora_A /\n",
    "        lora_B factors trainable). The remaining parameters are separated into two\n",
    "        buckets: those that will experience weight decay for regularization (Linear\n",
    "        weights and LoRA factors) and those that won't (biases, and\n",
    "        layernorm/embedding weights).\n",
    "\n",
    "        train_config must provide weight_decay, learning_rate and betas.\n",
    "        \"\"\"\n",
    "\n",
    "        # separate out all trainable parameters to those that will and won't experience regularizing weight decay\n",
    "        decay = set()\n",
    "        no_decay = set()\n",
    "        whitelist_weight_modules = (torch.nn.Linear, )\n",
    "        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)\n",
    "        for mn, m in self.named_modules():\n",
    "            for pn, p in m.named_parameters():\n",
    "                if not p.requires_grad:\n",
    "                    # frozen parameters are not handed to the optimizer at all\n",
    "                    continue\n",
    "                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n",
    "                # random note: because named_modules and named_parameters are recursive\n",
    "                # we will see the same tensors p many many times. but doing it this way\n",
    "                # allows us to know which parent module any tensor p belongs to...\n",
    "                if pn.rsplit('.', 1)[-1].startswith('lora_'):\n",
    "                    # LoRA factors are named lora_A / lora_B (neither 'weight' nor 'bias'); decay them\n",
    "                    decay.add(fpn)\n",
    "                elif pn.endswith('bias'):\n",
    "                    # all biases will not be decayed\n",
    "                    no_decay.add(fpn)\n",
    "                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):\n",
    "                    # weights of whitelist modules will be weight decayed\n",
    "                    decay.add(fpn)\n",
    "                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):\n",
    "                    # weights of blacklist modules will NOT be weight decayed\n",
    "                    no_decay.add(fpn)\n",
    "\n",
    "        # validate that we considered every trainable parameter\n",
    "        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}\n",
    "        inter_params = decay & no_decay\n",
    "        union_params = decay | no_decay\n",
    "        missing_params = param_dict.keys() - union_params\n",
    "        assert len(inter_params) == 0, \"parameters %s made it into both decay/no_decay sets!\" % (str(inter_params), )\n",
    "        assert len(missing_params) == 0, \"parameters %s were not separated into either decay/no_decay set!\" % (str(missing_params), )\n",
    "\n",
    "        # create the pytorch optimizer object, dropping a bucket that ended up empty\n",
    "        optim_groups = [\n",
    "            {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": train_config.weight_decay},\n",
    "            {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n",
    "        ]\n",
    "        optim_groups = [group for group in optim_groups if group[\"params\"]]\n",
    "        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)\n",
    "        return optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ccf2f6e-0eb8-4d1c-95a8-88fab35e86fd",
   "metadata": {},
   "source": [
    "## Learning to sort\n",
    "\n",
    "We use the [demo](https://github.com/karpathy/minGPT/blob/master/demo.ipynb) to check that our code is running fine!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7710790-040c-4e02-8bc4-8ff792685e5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Config:\n",
    "    \"\"\"Small hyperparameter bundle (attention size, dropout rates, LoRA settings).\n",
    "\n",
    "    Every attribute is annotated so that @dataclass registers it as a field;\n",
    "    without annotations the decorator generates no fields and no keyword\n",
    "    arguments for the constructor.\n",
    "    \"\"\"\n",
    "    n_head: int = 3\n",
    "    n_embd: int = 15\n",
    "    block_size: int = 11\n",
    "    # dropout hyperparameters\n",
    "    embd_pdrop: float = 0.1\n",
    "    resid_pdrop: float = 0.1\n",
    "    attn_pdrop: float = 0.1\n",
    "    # LoRA\n",
    "    lora_rank: int = 8\n",
    "    lora_alpha: float = 32\n",
    "\n",
    "# create a GPT instance\n",
    "model_config = GPT.get_default_config()\n",
    "model_config.model_type = 'gpt-nano'\n",
    "model_config.vocab_size = 3\n",
    "model_config.block_size = 100\n",
    "model_config.lora_rank = 8\n",
    "model_config.lora_alpha = 32\n",
    "\n",
    "model = GPT_LoRA(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "821f9cda-b6e8-48ed-9695-08ed244e0191",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from mingpt.utils import set_seed\n",
    "set_seed(3407)\n",
    "import pickle\n",
    "\n",
    "class SortDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Dataset for the Sort problem. E.g. for problem length 6:\n",
    "    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2\n",
    "    Which will feed into the transformer concatenated as:\n",
    "    input:  0 0 2 1 0 1 0 0 0 1 1\n",
    "    output: I I I I I 0 0 0 1 1 2\n",
    "    where I is \"ignore\", as the transformer is reading the input sequence\n",
    "\n",
    "    Examples are generated on the fly. An example belongs to the test split when\n",
    "    the CRC32 of its pickled digit list is a multiple of 4, otherwise to train.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, split, length=6, num_digits=3):\n",
    "        assert split in {'train', 'test'}\n",
    "        self.split = split\n",
    "        self.length = length\n",
    "        self.num_digits = num_digits\n",
    "\n",
    "    def __len__(self):\n",
    "        return 10000 # ...\n",
    "\n",
    "    def get_vocab_size(self):\n",
    "        return self.num_digits\n",
    "\n",
    "    def get_block_size(self):\n",
    "        # the length of the sequence that will feed into transformer,\n",
    "        # containing concatenated input and the output, but -1 because\n",
    "        # the transformer starts making predictions at the last input element\n",
    "        return self.length * 2 - 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # idx is ignored: every call draws a fresh random example\n",
    "\n",
    "        # use rejection sampling to generate an input example from the desired split\n",
    "        while True:\n",
    "            # generate some random integers\n",
    "            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)\n",
    "            # half of the time let's try to boost the number of examples that\n",
    "            # have a large number of repeats, as this is what the model seems to struggle\n",
    "            # with later in training, and they are kind of rare\n",
    "            if torch.rand(1).item() < 0.5:\n",
    "                if inp.unique().nelement() > self.length // 2:\n",
    "                    # too many unique digits, re-sample\n",
    "                    continue\n",
    "            # figure out if this generated example is train or test based on its hash.\n",
    "            # zlib.crc32 is used instead of the builtin hash() because hash() of bytes is\n",
    "            # salted per interpreter process, which makes the split differ between runs\n",
    "            h = zlib.crc32(pickle.dumps(inp.tolist()))\n",
    "            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test\n",
    "            if inp_split == self.split:\n",
    "                break # ok\n",
    "\n",
    "        # solve the task: i.e. sort\n",
    "        sol = torch.sort(inp)[0]\n",
    "\n",
    "        # concatenate the problem specification and the solution\n",
    "        cat = torch.cat((inp, sol), dim=0)\n",
    "\n",
    "        # the inputs to the transformer will be the offset sequence\n",
    "        x = cat[:-1].clone()\n",
    "        y = cat[1:].clone()\n",
    "        # we only want to predict at output locations, mask out the loss at the input locations\n",
    "        y[:self.length-1] = -1\n",
    "        return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60b4ce35-e2bd-48a5-a4b0-b7628a1ee511",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print an example instance of the dataset\n",
    "train_dataset = SortDataset('train')\n",
    "test_dataset = SortDataset('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f58d68e-b118-4378-803f-3b0a78a008b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = GPT.get_default_config()\n",
    "model_config.model_type = 'gpt-nano'\n",
    "model_config.vocab_size = train_dataset.get_vocab_size()\n",
    "model_config.block_size = 24 #train_dataset.get_block_size()\n",
    "model_config.lora_rank = 8\n",
    "model_config.lora_alpha = 32\n",
    "model_config.lora_dropout = 0\n",
    "model = GPT_LoRA(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e930e63-56da-48c3-a64b-270d5f2cd5a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a Trainer object\n",
    "from mingpt.trainer import Trainer\n",
    "\n",
    "train_config = Trainer.get_default_config()\n",
    "train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster\n",
    "train_config.max_iters = 1000#2000\n",
    "train_config.num_workers = 0\n",
    "trainer = Trainer(train_config, model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acbde1e6-3e44-4485-a6fd-cfb0d18c2910",
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_end_callback(trainer):\n",
    "    \"\"\"Print the iteration time and training loss every 100 iterations.\"\"\"\n",
    "    if trainer.iter_num % 100 != 0:\n",
    "        return\n",
    "    elapsed_ms = trainer.iter_dt * 1000\n",
    "    print(f\"iter_dt {elapsed_ms:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}\")\n",
    "trainer.set_callback('on_batch_end', batch_end_callback)\n",
    "\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fcc07ff-1d1c-4421-b65f-7bbac0215dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# now let's perform some evaluation\n",
    "model.eval();\n",
    "dataset = {'train':train_dataset, 'test':test_dataset}\n",
    "def eval_split(trainer, split, max_batches, dataset=dataset):\n",
    "    \"\"\"Greedy-decode the sorted half of each example and report exact-match accuracy.\n",
    "\n",
    "    Args:\n",
    "        trainer: only its ``device`` attribute is read, to place the batches.\n",
    "        split: key into ``dataset``.\n",
    "        max_batches: stop after this many batches of 100; None runs through the whole loader.\n",
    "        dataset: mapping from split name to a dataset exposing ``length``.\n",
    "\n",
    "    Returns:\n",
    "        Sum of the per-example 0/1 outcomes, as a float tensor.\n",
    "\n",
    "    Note: generation uses the notebook-level ``model``, not an attribute of ``trainer``.\n",
    "    \"\"\"\n",
    "    split_data = dataset[split]\n",
    "    n = split_data.length # naughty direct access shrug\n",
    "    outcomes = []\n",
    "    printed_mistakes = 0\n",
    "    loader = DataLoader(split_data, batch_size=100, num_workers=0, drop_last=False)\n",
    "    for batch_idx, (x, y) in enumerate(loader):\n",
    "        x = x.to(trainer.device)\n",
    "        y = y.to(trainer.device)\n",
    "        # isolate the input pattern and the reference solution\n",
    "        inp, sol = x[:, :n], y[:, -n:]\n",
    "        # let the model fill in the rest of the sequence with greedy argmax, not sampling\n",
    "        cat = model.generate(inp, n, do_sample=False)\n",
    "        sol_candidate = cat[:, -n:] # isolate the filled in sequence\n",
    "        # an example counts as correct only if every position matches\n",
    "        correct = (sol == sol_candidate).all(1).cpu()\n",
    "        for i in range(x.size(0)):\n",
    "            outcomes.append(int(correct[i]))\n",
    "            if correct[i] or printed_mistakes >= 3:\n",
    "                continue\n",
    "            # only print up to 3 mistakes to get a sense\n",
    "            printed_mistakes += 1\n",
    "            print(\"GPT claims that %s sorted is %s but gt is %s\" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))\n",
    "        if max_batches is not None and batch_idx + 1 >= max_batches:\n",
    "            break\n",
    "    rt = torch.tensor(outcomes, dtype=torch.float)\n",
    "    print(\"%s final score: %d/%d = %.2f%% correct\" % (split, rt.sum(), len(outcomes), 100*rt.mean()))\n",
    "    return rt.sum()\n",
    "\n",
    "# run a lot of examples from both train and test through the model and verify the output correctness\n",
    "with torch.no_grad():\n",
    "    train_score = eval_split(trainer, 'train', max_batches=50)\n",
    "    test_score  = eval_split(trainer, 'test',  max_batches=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb0f7f0f-3aa2-4039-8621-1aa1bf88fa19",
   "metadata": {},
   "source": [
    "Now we modify the distribution of the dataset a little bit and use LoRA to fine-tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "113f1a87-5d41-4d63-9b75-453cc69bac01",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset2 = SortDataset('train',length=10)\n",
    "test_dataset2 = SortDataset('test',length=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1986c07a-6b41-4cad-b147-c4fc7450f983",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset2 = {'train':train_dataset2, 'test':test_dataset2}\n",
    "with torch.no_grad():\n",
    "    train_score = eval_split(trainer, 'train', max_batches=50, dataset=dataset2)\n",
    "    test_score  = eval_split(trainer, 'test',  max_batches=50, dataset=dataset2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e55a597-1c86-4a70-9946-e67172c625ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# your code here for training with LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76ad19cf-c8a4-4734-acd7-59a182c7a5cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval();\n",
    "with torch.no_grad():\n",
    "    train_score = eval_split(trainer, 'train', max_batches=50, dataset=dataset2)\n",
    "    test_score  = eval_split(trainer, 'test',  max_batches=50, dataset=dataset2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa9747c5-6f4f-44da-8321-3134b808be7f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
